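Having covered CNNs, I'd now like to introduce another kind of neural network: the RNN (recurrent neural network). Unlike a CNN, an RNN processes data in sequence order, and the results computed at earlier steps affect the computation at later steps. The three main RNN variants are Simple RNN, LSTM, and GRU. LSTM addresses the simple RNN's weakness at retaining long-term memory, and GRU is a simplified version of LSTM. Next, we'll use stock price prediction as a worked example of how RNN models operate.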
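To make "earlier results affect later computation" concrete, here is a minimal sketch of the Simple RNN recurrence on scalar inputs, h_t = tanh(w_x * x_t + w_h * h_(t-1) + b), starting from h_0 = 0. The function name simple_rnn and the weight values are hypothetical, picked only for illustration; a real model uses vector inputs and learns its weight matrices during training. It is written in plain Python so it runs without a deep-learning library.

```python
import math
from typing import List


def simple_rnn(xs: List[float], w_x: float = 0.5, w_h: float = 0.8, b: float = 0.0) -> List[float]:
    """Scalar Simple RNN: h_t = tanh(w_x * x_t + w_h * h_{t-1} + b), with h_0 = 0.

    The weights are hypothetical illustration values, not trained parameters.
    """
    h = 0.0  # initial hidden state
    states = []
    for x in xs:
        # the new state depends on the current input AND the previous state
        h = math.tanh(w_x * x + w_h * h + b)
        states.append(h)
    return states


seq = [1.0, 0.0, -1.0]
print([round(h, 4) for h in simple_rnn(seq)])
print([round(h, 4) for h in simple_rnn(list(reversed(seq)))])
```

Feeding the same three values in reverse order ends in a different final state, which illustrates the order dependence described above: the hidden state h carries information from earlier steps into later ones.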
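The first step is to install the Python package twstock, a package dedicated to the Taiwan stock market that provides information on Taiwanese stocks. Run the following to install it (the leading ! runs the command from a Jupyter or Colab notebook cell; drop it when installing from a terminal).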
! pip install twstock 
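Next, we use twstock to get data for Formosa Plastics (台塑, stock code 1301), using 2021 as the example. The first run creates the CSV file. Rather than requesting the whole year at once, each run fetches and saves three months of data, so that our IP does not get blocked for sending too many requests. Before each run, just change the range of the for loop: its bounds are the range of months. Since Python's range() excludes the upper bound, range(1, 4) covers January through March.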
import csv
import os
import time
import twstock

filepath = "stock2021.csv"
# first run only: create the file (later runs use the append version below)
if not os.path.isfile(filepath):
    # column headers: date, shares traded, turnover, open, high, low, close, change, number of transactions
    title = ["日期", "成交股數", "成交金額", "開盤價", "最高價", "最低價", "收盤價", "漲跌價差", "成交筆數"]
    data = []
    # look up the stock by its code once, outside the loop
    stock = twstock.Stock("1301")
    for month in range(1, 4):   # months 1-3; the upper bound of range() is exclusive
        # fetch one month of daily records
        records = stock.fetch(2021, month)
        # this loop must sit inside the month loop, otherwise only the last month is saved
        for record in records:
            # convert the datetime object into a string
            strdate = record.date.strftime("%Y-%m-%d")
            li = [strdate, record.capacity, record.turnover, record.open, record.high,
                  record.low, record.close, record.change, record.transaction]
            data.append(li)
        # pause between requests (3 s is an assumed interval, not an official limit)
        time.sleep(3)
    # create the csv file: header row first, then the data rows
    with open(filepath, "w", newline="", encoding="big5") as outputfile:
        outputwriter = csv.writer(outputfile)
        outputwriter.writerow(title)
        outputwriter.writerows(data)
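Editing the loop bounds by hand before every run is easy to get wrong. As a hedged sketch, two hypothetical helpers can produce the three-month bounds and do the month-by-month fetching with a pause between requests. fetch_months takes any function called as fetch(year, month), so it can be tried offline with a stand-in; the 3-second default pause is an assumption, not a documented TWSE limit.

```python
import time
from typing import Callable, List, Sequence, Tuple


def month_batches(first: int = 1, last: int = 12, size: int = 3) -> List[Tuple[int, int]]:
    """Split months first..last into (lo, hi) pairs meant for range(lo, hi)."""
    return [(m, min(m + size, last + 1)) for m in range(first, last + 1, size)]


def fetch_months(fetch: Callable[[int, int], Sequence], year: int,
                 lo: int, hi: int, pause: float = 3.0) -> list:
    """Call fetch(year, month) for each month in range(lo, hi), pausing between calls."""
    rows: list = []
    for month in range(lo, hi):
        rows.extend(fetch(year, month))
        if month < hi - 1:
            time.sleep(pause)  # assumed pause length to avoid hammering the server
    return rows


# print the bounds to use on each run
for lo, hi in month_batches():
    print(f"range({lo}, {hi})")
```

With twstock, the fetch argument could be the fetch method of a twstock.Stock("1301") object, which takes a year and a month and returns that month's records, for example fetch_months(stock.fetch, 2021, 4, 7) for the second run.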
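From the second run onward, change the bounds of the for loop before running, and change the file mode from w to a, that is, from write to append. This way the newly fetched rows are added after the existing ones instead of overwriting the data downloaded earlier.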
import csv
import os
import time
import twstock

filepath = "stock2021.csv"
# this version assumes the file was already created by the first script
if os.path.isfile(filepath):
    data = []
    # look up the stock by its code once, outside the loop
    stock = twstock.Stock("1301")
    for month in range(10, 13):   # change the lower and upper bounds on every run (here: months 10-12)
        records = stock.fetch(2021, month)
        # this loop must sit inside the month loop, otherwise only the last month is saved
        for record in records:
            # convert the datetime object into a string
            strdate = record.date.strftime("%Y-%m-%d")
            li = [strdate, record.capacity, record.turnover, record.open, record.high,
                  record.low, record.close, record.change, record.transaction]
            data.append(li)
        # pause between requests (3 s is an assumed interval, not an official limit)
        time.sleep(3)
    # open in append mode ("a") so earlier rows are kept; the header is not written again
    with open(filepath, "a", newline="", encoding="big5") as outputfile:
        outputwriter = csv.writer(outputfile)
        outputwriter.writerows(data)
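The two scripts differ mainly in the file mode. As a sketch, that difference can be folded into one hypothetical helper, save_rows, which uses "w" and writes the header when the file does not exist yet, and uses "a" otherwise, so the same code works on every run. The Big5 default mirrors the encoding used above.

```python
import csv
import os
import tempfile
from typing import Iterable, Sequence


def save_rows(path: str, header: Sequence[str], rows: Iterable[Sequence],
              encoding: str = "big5") -> None:
    """Create the file with a header on the first call; append rows on later calls."""
    is_new = not os.path.isfile(path)
    # "w" creates (or overwrites) the file; "a" keeps existing content and adds to the end
    mode = "w" if is_new else "a"
    with open(path, mode, newline="", encoding=encoding) as f:
        writer = csv.writer(f)
        if is_new:
            writer.writerow(header)  # header only on the first write
        writer.writerows(rows)


# demo with placeholder rows in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo.csv")
    save_rows(demo_path, ["date", "close"], [["placeholder-1", "0"]])
    save_rows(demo_path, ["date", "close"], [["placeholder-2", "0"]])
    with open(demo_path, newline="", encoding="big5") as f:
        print(list(csv.reader(f)))
```

Calling save_rows twice yields one header followed by both batches, which is exactly what the w-then-a sequence of scripts achieves manually.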
Running this repeatedly, three months at a time, gives us Formosa Plastics stock data for January through December 2021!
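Because every run appends, accidentally re-running the same month range silently duplicates rows. A small check like the hypothetical check_coverage below can catch that. It assumes the layout produced above: one header row, the date in the first column formatted as YYYY-MM-DD, and Big5 encoding.

```python
import csv
from collections import Counter
from typing import List, Set, Tuple


def check_coverage(path: str, encoding: str = "big5") -> Tuple[Set[int], List[str]]:
    """Return the set of months present and any dates that appear more than once."""
    with open(path, newline="", encoding=encoding) as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        dates = [row[0] for row in reader if row]
    months = {int(d[5:7]) for d in dates}  # month digits of YYYY-MM-DD
    dupes = sorted(d for d, n in Counter(dates).items() if n > 1)
    return months, dupes
```

After the last run, months, dupes = check_coverage("stock2021.csv") should give months equal to set(range(1, 13)) and an empty dupes list; anything else points to a missed or repeated range.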